/* Copyright 2023 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

syntax = "proto3";

package xla.ifrt;

import "google/protobuf/any.proto";
import "tensorflow/compiler/xla/xla_data.proto";

// Represents a host callback in an XLA computation.
//
// An XLA computation may use XLA send/recv to represent communication between
// the host and the device. This can be used to implement "host callbacks",
// where host-side computation is invoked in the middle of an XLA computation.
// This message contains the information necessary to instantiate host
// callbacks that are marshalled from the client.
//
// Modeled after `xla::HostCallback`.
message XlaHostCallbackProto {
  // Metadata describing a single value exchanged with the host callback.
  // Used for both the `operands` and the `results` of the callback.
  message ArgInfo {
    // The channel id associated with this value in HLO. Declared as `uint32`
    // even though `xla::HostCallbackArgInfo::channel_id` is `uint16_t` because
    // protobuf doesn't have a 16-bit integer type.
    // NOTE(review): presumably values must fit in 16 bits to round-trip into
    // `uint16_t`; confirm that producers/consumers enforce this.
    uint32 channel_id = 1;

    // The host shape for the value.
    xla.ShapeProto shape = 2;
  }

  // The metadata (e.g. channel_id, shape) for each operand of the callback.
  repeated ArgInfo operands = 1;
  // The metadata (e.g. channel_id, shape) for each result of the callback.
  repeated ArgInfo results = 2;

  // Serialized host callback, packed into a `google.protobuf.Any`.
  // NOTE(review): the concrete message type carried here is not defined in
  // this file; confirm the expected type with the code that packs/unpacks it.
  google.protobuf.Any serialized_callback = 3;

  // See comment for PJRT
  // ExecuteOptions::use_major_to_minor_data_layout_for_callbacks.
  // As a proto3 scalar without explicit presence, this reads as `false` when
  // it is not set.
  bool use_major_to_minor_data_layout_for_callbacks = 4;
}
